# Distributed training example with automatic checkpoint/restore
#
# Creates a shared S3 mount across workers that is used to read and write
# checkpoints. Checkpoints are saved every 10 epochs.
#
# Adapted from michaelzhiluo and michaelvll's examples.
#
# Usage:
#   sky launch -c myclus checkpointed_training.yaml
#
#   # Interrupt training after a few tens of epochs with:
#   sky stop myclus
#
#   # Resume training (and see training resume from previous checkpoint):
#   sky start myclus
#   sky exec myclus checkpointed_training.yaml

# Name of this task.
name: resnet-distributed-app


# Hardware request: a V100 accelerator on AWS.
resources:
    accelerators: V100
    cloud: aws

# Number of nodes to run on. The run section does not hard-code this value:
# it derives the node count and the master address from SKYPILOT_NODE_IPS.
num_nodes: 1

# Storage mounted at /checkpoints on the cluster. The run section passes
# /checkpoints/torch_ddp_resnet/ as --model_dir, so checkpoints are written
# to (and resumed from) this mount.
# NOTE(review): `name` looks like a user-specific bucket name; presumably it
# must be changed to a bucket you own -- confirm against the SkyPilot
# storage documentation.
file_mounts:
    /checkpoints:
        name: sky-ckpt-romilb
        mode: MOUNT

setup: |
    # Upgrade pip before installing the project's requirements.
    pip3 install --upgrade pip

    # Start from a clean checkout. -f keeps rm from erroring on the first
    # launch, when the directory does not exist yet.
    rm -rf ./pytorch-distributed-resnet
    git clone https://github.com/romilbhardwaj/pytorch-distributed-resnet.git
    cd pytorch-distributed-resnet && pip3 install -r requirements.txt

    # Download CIFAR-10 into ./data and unpack it there. The extraction is
    # chained with && so tar never runs on a missing or partial archive.
    mkdir -p data saved_models
    cd data && \
    wget -c --quiet https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz && \
    tar -xvzf cifar-10-python.tar.gz

    # Directory on the /checkpoints mount that the run section uses as
    # --model_dir.
    mkdir -p /checkpoints/torch_ddp_resnet/

run: |
    # Enter the checkout created during setup and bring it up to date.
    cd pytorch-distributed-resnet
    git pull

    # SKYPILOT_NODE_IPS is read one address per line: the line count is used
    # as the number of nodes and the first line as the master address.
    node_count=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
    head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

    # Launch one process per node. Checkpoints are read from and written to
    # the /checkpoints mount; --resume asks the script to continue from an
    # existing checkpoint file.
    python3 -m torch.distributed.launch \
        --nproc_per_node=1 \
        --nnodes="${node_count}" \
        --node_rank=${SKYPILOT_NODE_RANK} \
        --master_addr="${head_ip}" \
        --master_port=8008 \
        resnet_ddp.py \
        --num_epochs 100 \
        --model_dir /checkpoints/torch_ddp_resnet/ \
        --resume \
        --model_filename resnet_distributed-with-epochs.pth
